/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package studio.raptor.ddal.core.connection;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.pool2.PooledObjectFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import studio.raptor.ddal.common.event.NativeEventBus;
import studio.raptor.ddal.common.exception.ConfigException;
import studio.raptor.ddal.common.exception.ConfigException.Code;
import studio.raptor.ddal.common.exception.GenericException;
import studio.raptor.ddal.common.exception.code.CommonErrorCodes;
import studio.raptor.ddal.common.util.StringUtil;
import studio.raptor.ddal.config.common.ConfigConstant;
import studio.raptor.ddal.config.common.ConfigTools;
import studio.raptor.ddal.config.config.ShardConfig;
import studio.raptor.ddal.config.config.SystemProperties;
import studio.raptor.ddal.config.model.shard.DataSource;
import studio.raptor.ddal.config.model.shard.DataSourceGroup;
import studio.raptor.ddal.config.model.shard.DataSourceGroups;
import studio.raptor.ddal.config.model.shard.PhysicalDBCluster;
import studio.raptor.ddal.core.connection.jdbc.JdbcBackendConnection;
import studio.raptor.ddal.core.connection.jdbc.JdbcConnectionFactory;
import studio.raptor.ddal.core.connection.jdbc.PooledJdbcConnectionFactory;
import studio.raptor.ddal.core.constants.ConnectionRwMode;
import studio.raptor.ddal.core.constants.DataSourceAccessLevel;

/**
 * Backend connection pool manager
 *
 * @author Sam
 * @since 3.0.0
 */
public enum BackendDataSourceManager {
  INSTANCE;

  /**
   * 后端连接池组对象模型。
   */
  private transient Map<String, Map<String, List<BackendDataSource>>> shardDataSourceGroup;
  private static Logger log = LoggerFactory.getLogger(BackendDataSourceManager.class);

  /**
   * 后端连接池管理器构造器。
   *
   * @throws GenericException 初始化连接时连接数据库发生的异常
   */
  BackendDataSourceManager() throws GenericException {
    shardDataSourceGroup = new HashMap<>();
    try {
      initPool();
      // 初始化成功之后注册连接池配置变化事件监听
      NativeEventBus.get().register(new DSChangeEventListener());
    } catch (SQLException sqlException) {
      throw new GenericException(CommonErrorCodes.COMMON_501, sqlException);
    }
  }

  public static void warmupBackendConnectionPool() {
    // do nothing
  }

  /**
   * 从数据源里获取后端物理连接。
   *
   * 物理连接的autoCommit属性是jdbc事务的关键。
   *
   * @param groupName 数据源组名
   * @param isReadOnly 是否只读
   * @return see {@link BackendConnection}
   * @throws SQLException SQLException
   */
  public static BackendConnection getBackendConnection(String groupName, boolean isReadOnly,
      boolean autoCommit) throws SQLException {
    Map<String, List<BackendDataSource>> groupDataSources = getGroupDataSource(groupName);

    List<BackendDataSource> backendDataSources =
        isReadOnly ? groupDataSources.get(ConnectionRwMode.R.getCode())
            : groupDataSources.get(ConnectionRwMode.W.getCode());

    checkBackendDataSources(backendDataSources);

    BackendDataSource dataSource = backendDataSources.get(0);
    checkDataSourceAccessLevel(isReadOnly, dataSource);

    BackendConnection conn = dataSource.getConnection();
    conn.setAutoCommit(autoCommit);
    return conn;
  }

  private static void checkBackendDataSources(List<BackendDataSource> backendDataSources) {
    if (null == backendDataSources || backendDataSources.isEmpty()) {
      throw new RuntimeException("No BackendDataSource available for now, try it later.");
    }
  }

  private static void checkDataSourceAccessLevel(boolean isReadOnly, BackendDataSource dataSource) {
    int level = dataSource.getAccessLevel().getLevel() & DataSourceAccessLevel.MASK;
    // 禁止访问连接校验
    if (level == DataSourceAccessLevel.BLOCK.getLevel()) {
      throw new GenericException(CommonErrorCodes.COMMON_512);
    }
    // AccessLevel是只读的数据源限制借写连接。
    if (level == DataSourceAccessLevel.R.getLevel()) {
      if (!isReadOnly) {
        throw new GenericException(CommonErrorCodes.COMMON_513);
      }
    }
  }

  static Map<String, List<BackendDataSource>> getGroupDataSource(String groupName) {
    return INSTANCE.getShardDataSourceGroup(groupName);
  }

  private DataSourceAccessLevel reflectAccessLevel(String accessLevelTexture) {
    DataSourceAccessLevel accessLevel;
    if (StringUtil.isEmpty(accessLevelTexture)) {
      accessLevel = DataSourceAccessLevel.RW;
    } else {
      accessLevel = DataSourceAccessLevel.textureOf(accessLevelTexture);
    }
    return accessLevel;
  }

  /**
   * init connection pool
   *
   * @throws SQLException Database access error
   */
  @SuppressWarnings("unchecked")
  private void initPool() throws SQLException {
    ShardConfig shardConfig = ShardConfig.getInstance();
    DataSourceGroups dataSourceGroups = shardConfig.getDataSourceGroups();
    for (DataSourceGroup dataSourceGroup : dataSourceGroups) {
      PhysicalDBCluster cluster = shardConfig
          .getPhysicalDBCluster(dataSourceGroup.getRelaCluster());
      DataSource[] dataSources = dataSourceGroup.getDataSources();

      for (DataSource dataSource : dataSources) {
        PhysicalDBCluster.DBInstance dbInstance = cluster.get(dataSource.getDbInstName());
        ConnectionRwMode dbInstanceRwMode = ConnectionRwMode.fromCode(dbInstance.getRw());
        switch (dbInstanceRwMode) {
          case W:
          case R:
                        /* config driver */
            BackendDataSource bcp;
            if ("jdbc".equals(dataSource.getDbDriver())) {
              String connectUrl = "";
              String driverClassName = "";
              if ("mysql".equals(cluster.getType())) {
                connectUrl = "jdbc:mysql://" + dbInstance.getHostname() + ":" + dbInstance.getPort();
                driverClassName = "com.mysql.jdbc.Driver";
              } else if ("oracle".equalsIgnoreCase(cluster.getType())) {
                connectUrl = String.format("jdbc:oracle:thin:@%s:%s:%s", dbInstance.getHostname(), dbInstance.getPort(), dbInstance.getSid());
                driverClassName = "oracle.jdbc.OracleDriver";
              } else if ("h2".equals(cluster.getType())) {
                connectUrl = String.format("jdbc:h2:%s/%s;IFEXISTS=TRUE;FILE_LOCK=SOCKET", dbInstance.getH2dir(), dbInstance.getH2db());
                driverClassName = "org.h2.Driver";
              }
              ConnectionFactory<JdbcBackendConnection> jdbcPhysicalConnFactory =
                  new JdbcConnectionFactory(driverClassName, connectUrl, dataSource.getUser(),
                      decryptPassword(dataSource.getPwd()));
              PooledObjectFactory jdbcPooledObjectFactory =
                  new PooledJdbcConnectionFactory(jdbcPhysicalConnFactory);

              bcp = new BackendDataSource(
                  dataSource.getDbInstName(),
                  reflectAccessLevel(dataSource.getAccessLevel()),
                  new PooledBackendConnectionFactory(jdbcPooledObjectFactory),
                  buildPoolConfig(dataSource.getParams())
              );

              // 设置数据源的只读属性。
              bcp.setReadOnlyPool(dbInstanceRwMode == ConnectionRwMode.R);

            } else {
              throw new GenericException(CommonErrorCodes.COMMON_500, "", dataSource.getDbDriver());
            }
            List<BackendDataSource> backendConnectionPools =
                getRwDataSources(dataSourceGroup.getName(),
                    ConnectionRwMode.fromCode(dbInstance.getRw()));
            backendConnectionPools.add(bcp);
        }
      }
    }
  }

  private String decryptPassword(String cipherText) {
    String password = cipherText;
    if (!"true"
        .equalsIgnoreCase(SystemProperties.getInstance().get(ConfigConstant.PROP_KEY_CONFIG_DECRYPT_ENABLED))) {
      return password;
    }
    ConfigTools.loadSysDecryptKey();
    try {
      password = ConfigTools.decrypt(ConfigTools.getSysDecryptKey(), cipherText);
    } catch (Exception e) {
      throw ConfigException.create(Code.DECRYPT_PASSWORD_ERROR);
    }
    return password;
  }

  /**
   * 从配置文件中读取连接池的配置参数。
   *
   * @param params 参数列表
   * @return 连接池参数配置对象。
   */
  private BackendConnectionPoolConfig buildPoolConfig(Map<String, String> params) {
    BackendConnectionPoolConfig poolConfig = new BackendConnectionPoolConfig();
    List<BackendConnectionPoolConfigParam> allConfigParams = Arrays
        .asList(BackendConnectionPoolConfigParam.values());

    List<String> allParamNames = new ArrayList<>(allConfigParams.size());
    for (BackendConnectionPoolConfigParam configParam : allConfigParams) {
      allParamNames.add(configParam.paramName);
    }
    for (Map.Entry<String, String> param : params.entrySet()) {
      BackendConnectionPoolConfigParam configParam = BackendConnectionPoolConfigParam
          .findByParamName(param.getKey());
      if (allParamNames.contains(param.getKey()) && null != configParam) {
        Object paramValue = getPoolParamNotNull(params, configParam.paramName,
            configParam.valueType);
        String setMethodName = "set" + param.getKey().substring(0, 1).toUpperCase() + param.getKey()
            .substring(1, param.getKey().length());
        try {
          // reflect to set pool param
          Method setMethod = BackendConnectionPoolConfig.class
              .getMethod(setMethodName, configParam.valueType);
          setMethod.invoke(poolConfig, paramValue);
        } catch (NoSuchMethodException exception) {
          log.error("No set method found for pool param [{}]", param.getKey());
        } catch (IllegalAccessException exception) {
          log.error("Illegal access to object {}", "BackendConnectionPoolConfig");
        } catch (InvocationTargetException exception) {
          log.error("Invocation of set config param value failed.");
        }
      } else {
        log.info("Unsupported pool param [{}], which has been ignored.", param.getKey());
      }
    }
    return poolConfig;
  }

  /**
   * 读取连接池参数，这个方法只读取整形的参数。如果参数没配置，或者配置的
   * 不是整形，则报对应的异常。
   *
   * @param params 配置文件中读取的连接池参数
   * @param paramName 连接池参数名
   * @return 整形参数值
   */
  private Object getPoolParamNotNull(Map<String, String> params, String paramName,
      Class<?> paramType) {
    if (!params.containsKey(paramName)) {
      throw new GenericException(CommonErrorCodes.COMMON_503, "", paramName);
    }
    String paramValue = params.get(paramName);
    if (StringUtil.isEmpty(paramValue)) {
      throw new GenericException(CommonErrorCodes.COMMON_504, "", paramValue, paramName);
    }
    if (paramType == long.class) {
      return Long.parseLong(paramValue);
    } else if (paramType == int.class) {
      return Integer.parseInt(paramValue);
    } else if (paramType == boolean.class) {
      return Boolean.valueOf(paramValue);
    } else if (paramType == String.class) {
      return String.valueOf(paramValue);
    } else {
      throw new IllegalArgumentException(String.format("Unsupported parameter type %s", paramType));
    }
  }


  /**
   * 获取数据源组的读写连接池
   *
   * @param groupName 数据源组
   * @return 读写连接池
   */
  private Map<String, List<BackendDataSource>> getShardDataSourceGroup(String groupName) {
    return shardDataSourceGroup.get(groupName);
  }

  /**
   * 获取数据源组中某种读写模式的连接池。在physicalDBCluster节点下对于一种
   * 读写模式可能会配置多个DbInstance，所以这个方法返回多个后端连接池。
   *
   * @param groupName datasource group name
   * @param rwMode read-write mode，{@link ConnectionRwMode}
   * @return 连接池组，在physicalDBCluster节点下
   */
  private List<BackendDataSource> getRwDataSources(String groupName, ConnectionRwMode rwMode) {
    Map<String, List<BackendDataSource>> shardDataSources = getShardDataSourceGroup(groupName);
    if (null == shardDataSources) {
      shardDataSources = new HashMap<>();
      this.shardDataSourceGroup.put(groupName, shardDataSources);
    }
    List<BackendDataSource> rwModeDataSources = shardDataSources.get(rwMode.getCode());
    if (null == rwModeDataSources) {
      rwModeDataSources = new ArrayList<>();
      shardDataSources.put(rwMode.getCode(), rwModeDataSources);
    }
    return rwModeDataSources;
  }
}
